#include "client.h"
#include "util.h"
#include "aes.h"
#include "sha1.h"
#include "residue_math.h"

#include <assert.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>

#include "my_exceptions.h"
void Client::set_mpk(const ZZ & n)
{
	mpk = n;
}
void Client::set_pk(const ZZ & n)
{
	pk = n;
}
void Client::set_sk(const ZZ  & n)
{
	sk = n;
}

/// @return a copy of the master public key (the working modulus).
ZZ Client::get_mpk()
{
	return this->mpk;
}

/// @return a copy of the public key value.
ZZ Client::get_pk()
{
	return this->pk;
}

/// @return a copy of the secret key.
ZZ Client::get_sk()
{
	return this->sk;
}

/**
 * Reads the master public key from a text file using ZZ's stream extractor.
 * The read result is not checked.
 *
 * @param mpk_file_name path of the file holding the key
 * @throws FileNotFoundException if the file cannot be opened
 */
void Client::load_mpk(const string &  mpk_file_name)
{
	std::ifstream in(mpk_file_name.c_str());
	if (!in)
		throw FileNotFoundException(mpk_file_name);

	in >> mpk;
	// The stream is closed by its destructor.
}

/**
 * Derives the public key from an identity stored in a text file.
 *
 * The first whitespace-delimited token of the file is taken as the identity
 * and pk is set to hash(identity, mpk), so mpk must already be loaded.
 *
 * @param pk_file_name path of the file holding the identity
 * @throws FileNotFoundException if the file cannot be opened
 */
void Client::load_pk( const string  & pk_file_name)
{
	std::ifstream in(pk_file_name.c_str());
	if (!in)
		throw FileNotFoundException(pk_file_name);

	std::string identity;
	in >> identity;
	pk = hash(identity.c_str(), mpk);
}

/**
 * Reads an unencrypted secret key from a text file using ZZ's stream
 * extractor. The read result is not checked.
 *
 * @param sk_file_name path of the file holding the key
 * @throws FileNotFoundException if the file cannot be opened
 */
void Client::load_sk( const string & sk_file_name)
{
	std::ifstream in(sk_file_name.c_str());
	if (!in)
		throw FileNotFoundException(sk_file_name);

	in >> sk;
	// The stream is closed by its destructor.
}

/**
 * Loads a password-protected secret key; the file contents are decrypted
 * by decrypt_sk().
 *
 * @param sk_file_name path of the AES-encrypted key file
 * @param password     password the AES key is derived from
 * @throws FileNotFoundException if the file cannot be opened
 * @throws whatever decrypt_sk() throws (the file is closed first)
 */
void Client::load_sk( const string & sk_file_name, string password)
{
	// Binary mode: the file holds raw AES ciphertext, and text mode would
	// translate line-ending bytes on some platforms and corrupt it.
	FILE * sk_file = fopen(sk_file_name.c_str(), "rb");

	if (sk_file == NULL)
	{
		throw FileNotFoundException(sk_file_name);
	}
	try
	{
		decrypt_sk(sk_file, password);
	}
	catch (...)
	{
		// Do not leak the handle when decryption fails.
		fclose(sk_file);
		throw;
	}
	fclose(sk_file);
}
void Client:: decrypt_sk(FILE * sk_file, string password)
{
	SHA1 sha;
	sha.Reset();
	unsigned int out[5];
	sha << password.c_str();
	sha.Result(out);
	unsigned char  key[KEY_LENGTH];
	memcpy(key, out, KEY_LENGTH);
	aes_context * ctx = new aes_context[KEY_LENGTH];
	aes_set_key(ctx, key, KEY_LENGTH * NUMBER_OF_BITS);

	fseek(sk_file, 0, SEEK_END);
	int length = ftell(sk_file);
	fseek(sk_file, 0, SEEK_SET);

	size_t number_of_blocks = length / BLOCK_LENGTH;
	unsigned char encrypted_sk_block[BLOCK_LENGTH];
	unsigned char decrypted_sk_block[BLOCK_LENGTH];
	unsigned char decrypted_sk[length];

	for (size_t i = 0; i != number_of_blocks; ++i)
	{
		fread(encrypted_sk_block, sizeof(char), BLOCK_LENGTH, sk_file);
		aes_decrypt(ctx, encrypted_sk_block, decrypted_sk_block);
		memcpy(decrypted_sk, & decrypted_sk_block[i * BLOCK_LENGTH], BLOCK_LENGTH);
	}
	long sk_length;
	memcpy(&sk_length, decrypted_sk,  sizeof(long));
	/*unsigned char byte_representation[20];
	memcpy(byte_representation, & decrypted_sk[sizeof(long)] , sk_length	);*/
	sk = ZZFromBytes(& decrypted_sk[sizeof(long)] , sk_length);
}

/**
 * Generates a fresh random AES key of KEY_LENGTH bytes.
 *
 * The bytes come from std::random_device, which uses the platform's entropy
 * source, rather than rand(), which is a small and predictable PRNG that
 * must not produce key material. (NOTE(review): some old MinGW toolchains
 * implement random_device deterministically; confirm on the target
 * platform.)
 *
 * @return a new[]-allocated array of KEY_LENGTH bytes; the caller owns it
 *         and must delete[] it
 */
unsigned char * Client::gen_aes_key()
{
	std::random_device entropy;
	unsigned char * key = new unsigned char[KEY_LENGTH];
	try
	{
		for (size_t i = 0; i != KEY_LENGTH; ++i)
		{
			key[i] = static_cast<unsigned char>(entropy() & 0xFFu);
		}
	}
	catch (...)
	{
		// random_device may throw if the entropy source is unavailable.
		delete[] key;
		throw;
	}
	return key;
}

/**
 * Encrypts one bit for the owner of pk.
 *
 * A bit is encoded as a Jacobi symbol: 1 stays 1 and 0 becomes -1.
 * Two values are produced, each from an independently chosen t whose
 * Jacobi symbol mod mpk equals that code:
 *   s1 = t  + pk / t   (mod mpk)
 *   s2 = t' - pk / t'  (mod mpk)
 * The receiver uses s1 if sk^2 == pk (mod mpk) and s2 if sk^2 == -pk.
 *
 * @param x the bit to encrypt (0 or 1)
 * @return the pair (s1, s2), both reduced into [0, mpk)
 */
Pair Client::encrypt_bit(int x)
{
	if (x == 0) x = -1;

	ZZ t = find_Jacobi_eq(x, mpk);
	ZZ invt;
	// inv() is called outside assert(): with NDEBUG the assertion expression
	// is not evaluated, and invt would never be computed.
	bool invertible = inv(t, mpk, invt);
	assert(invertible);
	(void) invertible;

	const ZZ s1 = (t + pk * invt) % mpk;

	t = find_Jacobi_eq(x, mpk);
	invertible = inv(t, mpk, invt);
	assert(invertible);
	(void) invertible;

	ZZ s2 = (pk * invt) % mpk;
	s2 = t - s2;
	// t and pk/t are both in [0, mpk), so adding mpk makes the difference
	// non-negative before the final reduction.
	s2 += mpk;
	s2 %= mpk;
	return Pair(s1, s2);
}


/**
 * Same encryption as encrypt_bit(int). It also reports, through `correct`,
 * whether gcd(1 + sk * t^-1, mpk) == 1 holds for the t used for BOTH halves.
 *
 * NOTE(review): the check presumably flags the case where decrypt_bit()
 * (which takes the Jacobi symbol of s + 2*sk) would not recover the bit.
 * Confirm this against the scheme.
 *
 * @param x       the bit to encrypt (0 or 1)
 * @param correct set to true only if the check passes for s1 and for s2
 * @return the pair (s1, s2), both reduced into [0, mpk)
 */
Pair Client::encrypt_bit(int x, bool& correct)
{
	if (x == 0) x = -1;

	ZZ t = find_Jacobi_eq(x, mpk);
	ZZ invt;
	// inv() is kept outside assert() so it still runs under NDEBUG.
	bool invertible = inv(t, mpk, invt);
	assert(invertible);
	(void) invertible;
	// Each half is checked separately and the results are combined below.
	// Assigning `correct` after each check would discard the s1 result.
	const bool first_ok = (gcd(1 + sk * invt, mpk) == 1);

	const ZZ s1 = (t + pk * invt) % mpk;

	t = find_Jacobi_eq(x, mpk);
	invertible = inv(t, mpk, invt);
	assert(invertible);
	(void) invertible;
	const bool second_ok = (gcd(1 + sk * invt, mpk) == 1);

	ZZ s2 = (pk * invt) % mpk;
	s2 = t - s2;
	s2 += mpk;
	s2 %= mpk;

	correct = first_ok && second_ok;
	return Pair(s1, s2);
}


int Client::decrypt_bit(const Pair & p)
{
	int x = 0;

	if ( sk * sk % mpk == pk)
	{
		x = calc_Jacobi(p.first + 2*sk, mpk);
	}

	else
	{
		x = calc_Jacobi(p.second + 2*sk, mpk);
	}

	if (x == -1) x = 0;
	return x;
}


void Client::encrypt(const string & message_file_name, const  string & ciphertext_file_name)
{
	FILE *  message_file = fopen(message_file_name.c_str(), "r");
	FILE * ciphertext_file = fopen(ciphertext_file_name.c_str(), "w");
	if (message_file  == NULL)
	{
		throw FileNotFoundException(message_file_name) ;
	}

	unsigned char * key =  gen_aes_key();
	unsigned int out[5];

	//Encrypt key for AEs

	unsigned char byte_representation[20];
	for (size_t i = 0; i != KEY_LENGTH; ++i)
	{
		unsigned char val = key[i];
	  	for (size_t j = 0; j != NUMBER_OF_BITS; ++j)
		{
			int bit = get(val, j);
			Pair p = encrypt_bit(bit);

			long number_size =(p.first). size() * 8;
			fwrite(&number_size, sizeof(long), 1, ciphertext_file);
			BytesFromZZ(byte_representation, p.first, number_size);
			fwrite(byte_representation, sizeof(char),number_size, ciphertext_file);

			number_size =(p.second). size() * 8;
			fwrite(&number_size, sizeof(long), 1, ciphertext_file);
			BytesFromZZ(byte_representation, p.second, number_size);
			fwrite(byte_representation, sizeof(char),number_size, ciphertext_file);

		}
	}


	fseek(message_file, 0, SEEK_END);
	int message_length = ftell(message_file);
	fseek(message_file, 0, SEEK_SET);


	//Encrypt message using AEs

	aes_context * ctx = new aes_context[KEY_LENGTH];
	aes_set_key(ctx, key, KEY_LENGTH * NUMBER_OF_BITS);
	delete[] key;

	SHA1  sha;
	sha.Reset();

	unsigned char message_block [BLOCK_LENGTH];
	unsigned char encrypted_message_block[BLOCK_LENGTH];

	memcpy(message_block, & message_length, sizeof(int));

	size_t bytes_readen = fread( &  message_block[sizeof(int)],  sizeof(char),  BLOCK_LENGTH  - sizeof(int),  message_file)   +	sizeof(int);
	for (unsigned char * q = & message_block [ bytes_readen ] ; 	q  != &message_block[BLOCK_LENGTH];  ++q ) *q = (char) rand();

	sha.Input(& message_block[sizeof(int)],	 BLOCK_LENGTH  -  sizeof(int) );
	//sha.Result(out);

	aes_encrypt(ctx, message_block, encrypted_message_block);
	fwrite(encrypted_message_block, sizeof(char), BLOCK_LENGTH,  ciphertext_file);


	while (! feof(message_file))
	{
		size_t bytes_readen = fread(message_block, sizeof(char),BLOCK_LENGTH,message_file);
		for (size_t i = bytes_readen;  i <  BLOCK_LENGTH; ++i)  message_block[i] = (char)  rand();

		sha.Input(& message_block [0], BLOCK_LENGTH );
		//sha.Result(out);

		aes_encrypt(ctx, message_block, encrypted_message_block);
		fwrite(encrypted_message_block, sizeof(char),BLOCK_LENGTH, ciphertext_file);

	}
	sha.Result(out);

	int number_of_blocks = sizeof(int) * HASH_LENGTH/BLOCK_LENGTH;
	if (sizeof(int) * HASH_LENGTH % BLOCK_LENGTH)  ++number_of_blocks;

	unsigned char hashed_message[BLOCK_LENGTH*number_of_blocks];
	memcpy(hashed_message, out, sizeof(int) * HASH_LENGTH);

	for (size_t i  =  sizeof(int) * HASH_LENGTH; i != BLOCK_LENGTH * number_of_blocks; ++ i )
	{
		hashed_message [ i ] = (unsigned char) rand();
	}
	unsigned char hashed_message_block [BLOCK_LENGTH];
	unsigned char encrypted_hashed_message_block[BLOCK_LENGTH];

	for (size_t i = 0; i != number_of_blocks; ++i)
	{
		memcpy(hashed_message_block, & hashed_message[ i * BLOCK_LENGTH], BLOCK_LENGTH);
		aes_encrypt(ctx, hashed_message_block, encrypted_hashed_message_block);
		fwrite(encrypted_hashed_message_block, sizeof(char), BLOCK_LENGTH, ciphertext_file);
		/*for (size_t i = 0; i != BLOCK_LENGTH; ++i)
		{
			std:: cout <<(int)  hashed_message_block[i] << " ";
		}*/
	}

	fclose(message_file);
	fclose(ciphertext_file);

	delete[] ctx;
}


void Client::decrypt(const string & ciphertext_file_name,  const string & decrypted_message_file_name)
{
	FILE * ciphertext_file = fopen(ciphertext_file_name.c_str(), "r");
	FILE * decrypted_message_file = fopen(decrypted_message_file_name.c_str(), "w");
	if (ciphertext_file  == NULL)
	{
		throw FileNotFoundException(ciphertext_file_name) ;
	}

	//Decrypt AES key

	int number_size = 0;
	unsigned char key[KEY_LENGTH];
	unsigned char byte_representation [20];
	for (size_t i = 0; i !=  KEY_LENGTH; ++i)
	{
		unsigned char decrypted_char = 0;
		unsigned char c = 128;
		for (size_t j = 0; j!= NUMBER_OF_BITS; ++j)
		{
			fread(&number_size, sizeof(long), 1, ciphertext_file);
			fread(byte_representation, sizeof(char),number_size, ciphertext_file);
			ZZ s_1 = ZZFromBytes(byte_representation, number_size);
			fread(&number_size, sizeof(long), 1, ciphertext_file);

			fread(byte_representation, sizeof(char),number_size, ciphertext_file);
			ZZ s_2 = ZZFromBytes(byte_representation, number_size);
			Pair p = Pair(s_1, s_2);
			int bit = decrypt_bit(p);
			decrypted_char += c*bit;
			c /= 2;
		}
		key[i] = decrypted_char;
	}


	aes_context * ctx = new aes_context[KEY_LENGTH];
	aes_set_key(ctx, key, KEY_LENGTH * NUMBER_OF_BITS);

	SHA1  sha;
	sha.Reset();
	unsigned int out[HASH_LENGTH];
	unsigned int out_2[HASH_LENGTH];


	//Decrypt message using AES key

	unsigned char  encrypted_message[BLOCK_LENGTH];
	unsigned char decrypted_message[BLOCK_LENGTH];

	fread(encrypted_message,sizeof(char),BLOCK_LENGTH, ciphertext_file);
	aes_decrypt(ctx, encrypted_message, decrypted_message);

	sha.Input(& decrypted_message[sizeof(int)], BLOCK_LENGTH - sizeof(int));

	int message_length;
	memcpy(&message_length, decrypted_message, sizeof(int));

	int first_BLOCK_LENGTH = sizeof(int) + message_length;
	// if message consists of 1 block
	if (first_BLOCK_LENGTH > BLOCK_LENGTH) first_BLOCK_LENGTH = BLOCK_LENGTH;
	fwrite(& decrypted_message[sizeof(int)], sizeof(char), first_BLOCK_LENGTH - sizeof(int) , decrypted_message_file);

	int remained_message_length =  message_length - (BLOCK_LENGTH - sizeof (int));
	if (remained_message_length < 0) remained_message_length = 0;

	size_t number_of_blocks = remained_message_length / BLOCK_LENGTH;
	size_t tail = remained_message_length % BLOCK_LENGTH;

	for (size_t i = 0; i != number_of_blocks; ++i)
	{
		fread(encrypted_message,sizeof(char),BLOCK_LENGTH, ciphertext_file);
		aes_decrypt(ctx, encrypted_message, decrypted_message);
		sha.Input(& decrypted_message[0],  BLOCK_LENGTH );

		fwrite(decrypted_message, sizeof(char), BLOCK_LENGTH, decrypted_message_file);

	}
	if (tail != 0)
	{
		fread(encrypted_message,sizeof(char),BLOCK_LENGTH, ciphertext_file);
		aes_decrypt(ctx, encrypted_message, decrypted_message);

		sha.Input(& decrypted_message[0],  BLOCK_LENGTH );

		fwrite(decrypted_message, sizeof(char), tail, decrypted_message_file);
	}
	sha.Result(out);

	number_of_blocks = (sizeof(int) * HASH_LENGTH) / BLOCK_LENGTH;
	if (sizeof(int) * HASH_LENGTH  % BLOCK_LENGTH) 	++number_of_blocks;

	unsigned char hashed_message[BLOCK_LENGTH * number_of_blocks];
	unsigned char encrypted_hashed_message_block [BLOCK_LENGTH];
	unsigned char  hashed_message_block [BLOCK_LENGTH];

	for (size_t i = 0; i != number_of_blocks; ++i)
	{
		fread(&encrypted_hashed_message_block[0], sizeof(char), BLOCK_LENGTH, ciphertext_file);
		aes_decrypt(ctx, encrypted_hashed_message_block, hashed_message_block);

		/*for (size_t j = 0; j != BLOCK_LENGTH; ++j)
		{
			std:: cout <<(int)  hashed_message_block[j] << " ";
		}*/

		memcpy(& hashed_message [ i * BLOCK_LENGTH], hashed_message_block, BLOCK_LENGTH);
	}
	memcpy (out_2, hashed_message, sizeof(int) * HASH_LENGTH);


	bool corrupted = false;
	for (size_t i = 0 ; i != HASH_LENGTH ; ++i)
	{
		if (out[ i ] != out_2 [ i ]) corrupted = true;
	}

	if (corrupted) std::cout << "\nWarning!\nMessage has been altered!\n";
	fclose(ciphertext_file);
	fclose(decrypted_message_file);

	delete[] ctx;

}

